import torch
from torch import nn,Tensor
from typing import Optional,List,Tuple
from torch.onnx.symbolic_helper import parse_args

class MatMulInteger(torch.autograd.Function):
    """Integer matmul stand-in.

    At eager-execution time the product is computed in float32 (torch has no
    int8 matmul kernel on CUDA); on ONNX export the node is emitted as the
    quantized ``MatMulInteger`` operator instead.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, weight_t: torch.Tensor):
        # Upcast both operands to float32 and multiply; the integer inputs
        # are small enough (int8 range) that float32 accumulates exactly.
        lhs = x.to(dtype=torch.float32)
        rhs = weight_t.to(dtype=torch.float32)
        return torch.matmul(lhs, rhs)

    @staticmethod
    @parse_args("v", "v")
    def symbolic(g: torch._C.Graph, x: torch.Tensor, weight_t: torch.Tensor):
        # ONNX export path: map directly onto the MatMulInteger op.
        return g.op("MatMulInteger", x, weight_t)

# Convenience alias used by the quantized matmul helpers below.
matmulInteger = MatMulInteger.apply

def qMatmul(x_q: Tensor, x_max: Tensor, weight_q: Tensor, w_max: Tensor, dtype):
    """Dequantizing matmul for per-row quantized operands.

    Multiplies the int8 tensors via ``matmulInteger``, then rescales each
    output element by the outer product of the activation and weight scale
    vectors, returning the result in ``dtype``.
    """
    acc = matmulInteger(x_q, weight_q)
    # linear(a[..., None], b[:, None]) == outer product of the two scale
    # vectors: scale[..., i, j] = x_max[..., i] * w_max[j].
    scale = nn.functional.linear(x_max.unsqueeze(-1), w_max.unsqueeze(-1))
    # Rescale in float32 (acc may live on a different device than the scales),
    # then cast down to the caller-requested dtype.
    acc = acc.to(device=scale.device, dtype=torch.float32)
    return acc.mul(scale.to(torch.float32)).to(dtype=dtype)

def quantize_mat(mat:Tensor)-> Tuple[Tensor,Tensor]:
    max_val = (torch.max(torch.abs(mat),dim=-1)[0] / 127.0).to(dtype=torch.float16)
    mat =  (mat / max_val[...,None]).to(dtype=torch.int8)
    return mat, max_val


class W8X8Linear(nn.Module):
    """Linear layer with int8 weights and on-the-fly int8 activations.

    The original weight is quantized once (per-row, symmetric) at
    construction; activations are quantized per forward call and the product
    is dequantized back to the input dtype.

    NOTE(review): ``act_max`` and ``alpha`` are accepted but never used, and
    ``self.scales`` is only ever None here — presumably hooks for activation
    smoothing set up elsewhere; confirm before relying on them.
    """

    def __init__(self, ori_w: Tensor, bias: Optional[Tensor] = None,
                 act_max: Optional[Tensor] = None, alpha=32):
        super().__init__()
        self.bias = bias.detach() if bias is not None else None
        self.dtype = ori_w.dtype      # dtype to restore on output
        self.alpha = alpha
        self.scales = None            # optional activation divisor (unset here)
        w_q, w_scale = quantize_mat(ori_w.detach())
        # Store the weight transposed so forward can matmul directly.
        self.weight_q = nn.Parameter(w_q.t(), requires_grad=False)
        self.max_val = nn.Parameter(w_scale, requires_grad=False)

    def forward(self, x: Tensor) -> Tensor:
        if self.scales is not None:
            x = x.div(self.scales)
        x_q, x_scale = quantize_mat(x)
        out = qMatmul(x_q, x_scale, self.weight_q, self.max_val, x.dtype)
        return out if self.bias is None else out + self.bias
# Registry mapping a quantization scheme name (the cfg entry's 'type' field,
# see replace_linear_modules) to the module class that implements it.
quant_cls = {
    "W8X8":W8X8Linear
}
def replace_linear_modules(module: nn.Module, prefix: str, act_scales, cfg):
    """Recursively replace nn.Linear children whose name appears in ``cfg``
    with the quantized module class named by ``cfg[name]['type']``.

    ``prefix`` is the dotted path of ``module`` ('' at the root); it is used
    to look up per-layer activation scales in ``act_scales``.
    """
    for child_name, child in module.named_children():
        full_name = child_name if prefix == '' else f"{prefix}.{child_name}"
        if not (isinstance(child, nn.Linear) and child_name in cfg):
            # Not a target layer: keep descending.
            replace_linear_modules(child, full_name, act_scales, cfg)
            continue
        opts = cfg[child_name]
        # Activation scale is only passed when the cfg entry opts in.
        scale = (act_scales[full_name]
                 if act_scales is not None and 'act_scale' in opts else None)
        alpha = opts.get('alpha', 128)
        quant_layer = quant_cls[opts['type']](child.weight, child.bias,
                                              act_max=scale, alpha=alpha)
        setattr(module, child_name, quant_layer)
            
def quantize(model: nn.Module, smooth: bool = False,
             act_scales_path: Optional[str] = None, cfg: Optional[dict] = None):
    """Quantize ``model`` in place according to ``cfg``.

    Fixes vs. the previous version:
    - ``cfg={}`` was a mutable default argument; use None and normalize.
    - The ``smooth`` and ``act_scales_path`` parameters were silently ignored;
      they now act as fallbacks when the corresponding cfg keys are absent
      (cfg keys still take precedence, so existing callers are unaffected).
    """
    cfg = {} if cfg is None else cfg
    act_scales = None
    scales_path = cfg.get('act_scales_path', act_scales_path)
    if scales_path is not None:
        act_scales = torch.load(scales_path)
        if 'smooth' in cfg or smooth:
            # Lazy import: smoothing is optional and lives in a sibling module.
            from smooth import smooth_lm
            smooth_lm(model, act_scales, 0.85)
    replace_linear_modules(model, '', act_scales, cfg)